"""
Utilities for extensions of and operations on Python collections.
"""

import io
import itertools
import types
from collections import OrderedDict
from collections.abc import (
    Callable,
    Collection,
    Generator,
    Hashable,
    Iterable,
    Iterator,
    Sequence,
    Set,
)
from dataclasses import fields, is_dataclass, replace
from enum import Enum, auto
from typing import (
    TYPE_CHECKING,
    Any,
    Literal,
    Optional,
    Union,
    cast,
    overload,
)
from unittest.mock import Mock

import pydantic
from typing_extensions import TypeAlias, TypeVar

# Quote moved to `prefect.utilities.annotations` but preserved here for compatibility
from prefect.utilities.annotations import BaseAnnotation as BaseAnnotation
from prefect.utilities.annotations import Quote as Quote
from prefect.utilities.annotations import quote as quote

if TYPE_CHECKING:
    pass


class AutoEnum(str, Enum):
    """
    String-valued enum whose member values are derived from the member names.

    Declaring members with `AutoEnum.auto()` makes each value equal to the
    member's own name. Renaming a member therefore can never leave a stale value
    behind. Because the class also inherits from `str`, members compare equal to
    their string values and serialize to JSON as plain strings.

    See https://docs.python.org/3/library/enum.html#using-automatic-values

    Example:
        ```python
        class MyEnum(AutoEnum):
            RED = AutoEnum.auto() # equivalent to RED = 'RED'
            BLUE = AutoEnum.auto() # equivalent to BLUE = 'BLUE'
        ```
    """

    @staticmethod
    def _generate_next_value_(name: str, *_args: object, **_kwargs: object) -> str:
        # Invoked by the enum machinery for every `auto()` member; the remaining
        # positional arguments (start, count, last_values) are irrelevant here.
        return name

    @staticmethod
    def auto() -> str:
        """
        Shortcut for `enum.auto()` so subclasses need no extra import.
        """
        return auto()

    def __repr__(self) -> str:
        return "{}.{}".format(type(self).__name__, self.value)


# Generic type variables shared by the dictionary helpers below. Those created
# with `infer_variance=True` (a typing_extensions feature) ask type checkers to
# infer their variance from usage instead of treating them as invariant.
KT = TypeVar("KT")
VT = TypeVar("VT", infer_variance=True)
VT1 = TypeVar("VT1", infer_variance=True)
VT2 = TypeVar("VT2", infer_variance=True)
R = TypeVar("R", infer_variance=True)
# A dict whose values are either leaf values of type VT or further NestedDicts
# with the same key type (a recursive alias, hence the string forward reference).
NestedDict: TypeAlias = dict[KT, Union[VT, "NestedDict[KT, VT]"]]
# Type variable for values that must be hashable (dict keys, set members).
HashableT = TypeVar("HashableT", bound=Hashable)


def dict_to_flatdict(dct: NestedDict[KT, VT]) -> dict[tuple[KT, ...], VT]:
    """Converts a (nested) dictionary to a flattened representation.

    Each key of the flat dict will be a CompoundKey tuple containing the "chain of keys"
    for the corresponding value. Non-empty nested dictionaries are descended into;
    empty dictionaries are kept as leaf values.

    Args:
        dct (dict): The dictionary to flatten

    Returns:
        A flattened dict of the same type as dct
    """

    def flatten(
        dct: NestedDict[KT, VT], _parent: tuple[KT, ...] = ()
    ) -> Iterator[tuple[tuple[KT, ...], VT]]:
        for k, v in dct.items():
            k_parent = (*_parent, k)
            # if v is a non-empty dict, recurse
            if isinstance(v, dict) and v:
                yield from flatten(cast(NestedDict[KT, VT], v), _parent=k_parent)
            else:
                yield (k_parent, cast(VT, v))

    type_ = cast(type[dict[tuple[KT, ...], VT]], type(dct))
    # Build via item assignment rather than `type_(pairs)`. `defaultdict` rejects
    # a non-callable first argument, and `Counter` would count the (key, value)
    # pairs as elements. This mirrors how `flatdict_to_dict` builds its result.
    result = type_()
    for flat_key, value in flatten(dct):
        result[flat_key] = value
    return result


def flatdict_to_dict(dct: dict[tuple[KT, ...], VT]) -> NestedDict[KT, VT]:
    """Rebuild a nested dictionary from its flattened form.

    This is the inverse of `dict_to_flatdict`: each tuple key is treated as a path,
    and intermediate dictionaries (of the same type as `dct`) are created as
    needed along that path.

    Args:
        dct (dict): The flat dictionary to nest. Each key should be a tuple of keys
            as generated by `dict_to_flatdict`

    Returns:
        A nested dict of the same type as dct
    """
    mapping_type = cast(type[NestedDict[KT, VT]], type(dct))
    root = mapping_type()

    for path, leaf in dct.items():
        *branch, final_key = path
        node = root
        # Walk (and create where missing) the intermediate levels of the path
        for step in branch:
            try:
                node = cast(NestedDict[KT, VT], node[step])
            except KeyError:
                child = mapping_type()
                node[step] = child
                node = child

        node[final_key] = leaf

    return root


# Generic element type used by the collection helpers below.
T = TypeVar("T")


def isiterable(obj: Any) -> bool:
    """
    Return a boolean indicating if an object is iterable.

    Excludes types that are iterable but typically used as singletons:
    - str
    - bytes
    - IO objects

    Never raises for excluded types: IO objects are rejected without being
    iterated, so closed files return `False` instead of raising `ValueError`.
    """
    # Check the exclusions before calling `iter()`: iterating a closed IO
    # object raises ValueError ("I/O operation on closed file"), not TypeError,
    # which would otherwise escape this predicate.
    if isinstance(obj, (str, bytes, io.IOBase)):
        return False
    try:
        iter(obj)
    except TypeError:
        return False
    return True


def ensure_iterable(obj: Union[T, Iterable[T]]) -> Collection[T]:
    """
    Return `obj` unchanged if it is a `Sequence` or a `Set`; otherwise wrap it in
    a single-element list.

    Note that `str` and `bytes` are sequences and are therefore returned as-is,
    while mappings, generators, and other iterables that are neither sequences
    nor sets are wrapped as a single item.
    """
    if not isinstance(obj, (Sequence, Set)):
        # Treat anything else as one element
        return [cast(T, obj)]
    return cast(Collection[T], obj)


def listrepr(objs: Iterable[Any], sep: str = " ") -> str:
    """Return the `repr` of every object in `objs`, joined by `sep`."""
    return sep.join(map(repr, objs))


def extract_instances(
    objects: Iterable[Any],
    types: Union[type[T], tuple[type[T], ...]] = object,
) -> Union[list[T], dict[type[T], list[T]]]:
    """
    Collect the instances of the requested type(s) found in `objects`.

    Args:
        objects: An iterable of objects
        types: A type or tuple of types to extract, defaults to all objects

    Returns:
        If exactly one type is given (a bare type or a 1-tuple): a list of the
        instances of that type, which is empty when nothing matches.
        Otherwise: a mapping of each requested type to a list of its instances.
        Types with no instances are absent from the mapping. An object that is an
        instance of several requested types is listed under each of them.
    """
    types_collection = ensure_iterable(types)

    # Create a mapping of type -> instances
    ret: dict[type[T], list[Any]] = {}

    for o in objects:
        # We iterate here so that the key is the passed type rather than type(o)
        for type_ in types_collection:
            if isinstance(o, type_):
                ret.setdefault(type_, []).append(o)

    if len(types_collection) == 1:
        [type_] = types_collection
        # `.get` so a single requested type with no matches yields [] rather
        # than raising KeyError
        return ret.get(type_, [])

    return ret


def batched_iterable(
    iterable: Iterable[T], size: int
) -> Generator[tuple[T, ...], None, None]:
    """
    Lazily split an iterable into consecutive tuples of at most `size` items.

    Args:
        iterable (Iterable): The source of items
        size (int): The maximum number of items per batch

    Yields:
        tuple: The next batch; only the final batch may be shorter than `size`
    """
    source = iter(iterable)
    # `iter(callable, sentinel)` keeps calling the lambda until it returns the
    # sentinel `()`, i.e. until the source is exhausted.
    yield from iter(lambda: tuple(itertools.islice(source, size)), ())


class StopVisiting(BaseException):
    """
    Raise this from a `visit_collection` visitor to stop descending along the
    current path.

    The expression being visited is kept as-is and none of its children are
    visited. The class derives from `BaseException` rather than `Exception`,
    presumably so that broad `except Exception` handlers inside visitors do not
    intercept it.
    """


# Typing-only overloads for `visit_collection`. They pair a two-argument
# `visit_fn` with a `context` dict, and a one-argument `visit_fn` with
# `context=None`. They also narrow the return type by `return_data`.
# NOTE(review): the `Literal[True] = ...` overloads also match calls that omit
# `return_data`, whose runtime default is `False` (the function then returns
# None) — confirm this typing is intended.
@overload
def visit_collection(
    expr: Any,
    visit_fn: Callable[[Any, dict[str, VT]], Any],
    *,
    return_data: Literal[True] = ...,
    max_depth: int = ...,
    context: dict[str, VT] = ...,
    remove_annotations: bool = ...,
    _seen: Optional[dict[int, Any]] = ...,
) -> Any: ...


# Single-argument visitor, data returned.
@overload
def visit_collection(
    expr: Any,
    visit_fn: Callable[[Any], Any],
    *,
    return_data: Literal[True] = ...,
    max_depth: int = ...,
    context: None = None,
    remove_annotations: bool = ...,
    _seen: Optional[dict[int, Any]] = ...,
) -> Any: ...


# Context-taking visitor, `return_data` only known as a bool.
@overload
def visit_collection(
    expr: Any,
    visit_fn: Callable[[Any, dict[str, VT]], Any],
    *,
    return_data: bool = ...,
    max_depth: int = ...,
    context: dict[str, VT] = ...,
    remove_annotations: bool = ...,
    _seen: Optional[dict[int, Any]] = ...,
) -> Optional[Any]: ...


# Single-argument visitor, `return_data` only known as a bool.
@overload
def visit_collection(
    expr: Any,
    visit_fn: Callable[[Any], Any],
    *,
    return_data: bool = ...,
    max_depth: int = ...,
    context: None = None,
    remove_annotations: bool = ...,
    _seen: Optional[dict[int, Any]] = ...,
) -> Optional[Any]: ...


# Context-taking visitor used only for side effects: returns None.
@overload
def visit_collection(
    expr: Any,
    visit_fn: Callable[[Any, dict[str, VT]], Any],
    *,
    return_data: Literal[False] = False,
    max_depth: int = ...,
    context: dict[str, VT] = ...,
    remove_annotations: bool = ...,
    _seen: Optional[dict[int, Any]] = ...,
) -> None: ...


def visit_collection(
    expr: Any,
    visit_fn: Union[Callable[[Any, dict[str, VT]], Any], Callable[[Any], Any]],
    *,
    return_data: bool = False,
    max_depth: int = -1,
    context: Optional[dict[str, VT]] = None,
    remove_annotations: bool = False,
    _seen: Optional[dict[int, Any]] = None,
) -> Optional[Any]:
    """
    Visits and potentially transforms every element of an arbitrary Python collection.

    If an element is a Python collection, it will be visited recursively. If an element
    is not a collection, `visit_fn` will be called with the element. The return value of
    `visit_fn` can be used to alter the element if `return_data` is set to `True`.

    Note:
    - When `return_data` is `True`, a copy of each collection is created only if
      `visit_fn` modifies an element within that collection. This approach minimizes
      performance penalties by avoiding unnecessary copying.
    - When `return_data` is `False`, no copies are created, and only side effects from
      `visit_fn` are applied. This mode is faster and should be used when no transformation
      of the collection is required, because it never has to copy any data.

    Supported types:
    - List
    - Tuple (namedtuples are rebuilt by passing their fields positionally)
    - Set
    - Dict (note: keys are also visited recursively; a rebuilt `defaultdict` keeps
      its `default_factory`)
    - Dataclass
    - Pydantic model
    - Prefect annotations

    Note that visit_collection will not consume generators or async generators, as it would prevent
    the caller from iterating over them. Other iterators are not descended into.

    Args:
        expr (Any): A Python object or expression.
        visit_fn (Callable[[Any, Optional[dict]], Any] or Callable[[Any], Any]): A function
            that will be applied to every non-collection element of `expr`. The function can
            accept one or two arguments. If two arguments are accepted, the second argument
            will be the context dictionary.
        return_data (bool): If `True`, a copy of `expr` containing data modified by `visit_fn`
            will be returned. This is slower than `return_data=False` (the default).
        max_depth (int): Controls the depth of recursive visitation. If set to zero, no
            recursion will occur. If set to a positive integer `N`, visitation will only
            descend to `N` layers deep. If set to any negative integer, no limit will be
            enforced and recursion will continue until terminal items are reached. By
            default, recursion is unlimited.
        context (Optional[dict]): An optional dictionary. If passed, the context will be sent
            to each call to the `visit_fn`. The context can be mutated by each visitor and
            will be available for later visits to expressions at the given depth. Values
            will not be available "up" a level from a given expression.
            The context will be automatically populated with an 'annotation' key when
            visiting collections within a `BaseAnnotation` type. This requires the caller to
            pass `context={}` and will not be activated by default.
        remove_annotations (bool): If set, annotations will be replaced by their contents. By
            default, annotations are preserved but their contents are visited.
        _seen (Optional[dict[int, Any]]): A mapping from the id of each object already
            visited to its (possibly transformed) result. This prevents infinite recursion
            when visiting recursive data structures.

    Returns:
        Any: The modified collection if `return_data` is `True`, otherwise `None`.
    """

    if _seen is None:
        _seen = {}

    if context is not None:
        _callback = cast(Callable[[Any, dict[str, VT]], Any], visit_fn)

        def visit_nested(expr: Any) -> Optional[Any]:
            return visit_collection(
                expr,
                _callback,
                return_data=return_data,
                remove_annotations=remove_annotations,
                max_depth=max_depth - 1,
                # Copy the context on nested calls so it does not "propagate up"
                context=context.copy(),
                _seen=_seen,
            )

        def visit_expression(expr: Any) -> Any:
            return _callback(expr, context)
    else:
        _callback = cast(Callable[[Any], Any], visit_fn)

        def visit_nested(expr: Any) -> Optional[Any]:
            # Utility for a recursive call, preserving options and updating the depth.
            return visit_collection(
                expr,
                _callback,
                return_data=return_data,
                remove_annotations=remove_annotations,
                max_depth=max_depth - 1,
                _seen=_seen,
            )

        def visit_expression(expr: Any) -> Any:
            return _callback(expr)

    # --- 1. Visit every expression
    try:
        result = visit_expression(expr)
    except StopVisiting:
        max_depth = 0
        result = expr

    if return_data:
        # Only mutate the root expression if the user indicated we're returning data,
        # otherwise the function could return null and we have no collection to check
        expr = result

    # --- 2. Visit every child of the expression recursively

    # If we have reached the maximum depth or we have already visited this object,
    # return the result if we are returning data, otherwise return None
    obj_id = id(expr)
    if max_depth == 0:
        return result if return_data else None
    elif obj_id in _seen:
        # Return the cached transformed result
        return _seen[obj_id] if return_data else None

    # Mark this object as being processed to handle circular references
    # We'll update with the actual result later
    _seen[obj_id] = expr

    # Then visit every item in the expression if it is a collection

    # presume that the result is the original expression.
    # in each of the following cases, we will update the result if we need to.
    result = expr

    # --- Generators

    if isinstance(expr, (types.GeneratorType, types.AsyncGeneratorType)):
        # Do not attempt to iterate over generators, as it will exhaust them
        pass

    # --- Mocks

    elif isinstance(expr, Mock):
        # Do not attempt to recurse into mock objects
        pass

    # --- Annotations (unmapped, quote, etc.)

    elif isinstance(expr, BaseAnnotation):
        annotated = cast(BaseAnnotation[Any], expr)
        if context is not None:
            context["annotation"] = cast(VT, annotated)
        unwrapped = annotated.unwrap()
        value = visit_nested(unwrapped)

        if return_data:
            # if we are removing annotations, return the value
            if remove_annotations:
                result = value
            # if the value was modified, rewrap it
            elif value is not unwrapped:
                result = annotated.rewrap(value)
            # otherwise return the expr

    # --- Sequences

    elif isinstance(expr, (list, tuple, set)):
        seq = cast(Union[list[Any], tuple[Any], set[Any]], expr)
        items = [visit_nested(o) for o in seq]
        if return_data:
            modified = any(item is not orig for item, orig in zip(items, seq))
            if modified:
                if isinstance(seq, tuple) and hasattr(seq, "_fields"):
                    # namedtuples take one positional argument per field, not a
                    # single iterable, so `type(seq)(items)` would raise TypeError
                    result = type(seq)(*items)
                else:
                    result = type(seq)(items)

    # --- Dictionaries

    elif isinstance(expr, (dict, OrderedDict)):
        mapping = cast(dict[Any, Any], expr)
        items = [(visit_nested(k), visit_nested(v)) for k, v in mapping.items()]
        if return_data:
            modified = any(
                k1 is not k2 or v1 is not v2
                for (k1, v1), (k2, v2) in zip(items, mapping.items())
            )
            if modified:
                # Imported lazily: only needed on this rarely taken rebuild path
                from collections import defaultdict

                if isinstance(mapping, defaultdict):
                    # defaultdict's first positional argument is its factory;
                    # passing the items there raises TypeError
                    result = type(mapping)(mapping.default_factory, items)
                else:
                    result = type(mapping)(items)

    # --- Dataclasses

    elif is_dataclass(expr) and not isinstance(expr, type):
        expr_fields = fields(expr)
        values = [visit_nested(getattr(expr, f.name)) for f in expr_fields]
        if return_data:
            modified = any(
                getattr(expr, f.name) is not v for f, v in zip(expr_fields, values)
            )
            if modified:
                result = replace(
                    expr, **{f.name: v for f, v in zip(expr_fields, values)}
                )

    # --- Pydantic models

    elif isinstance(expr, pydantic.BaseModel):
        # when extra=allow, fields not in model_fields may be in model_fields_set
        original_data = dict(expr)
        updated_data = {
            field: visit_nested(value) for field, value in original_data.items()
        }

        if return_data:
            modified = any(
                original_data[field] is not updated_data[field]
                for field in updated_data
            )
            if modified:
                # Use construct to avoid validation and handle immutability
                model_instance = expr.model_construct(
                    _fields_set=expr.model_fields_set, **updated_data
                )
                for private_attr in expr.__private_attributes__:
                    setattr(model_instance, private_attr, getattr(expr, private_attr))
                result = model_instance

    # Update the cache with the final transformed result
    # NOTE(review): this overwrites the reference to `expr`. If `expr` was a
    # temporary returned by `visit_fn`, it may be freed and its id reused later
    # in the traversal, producing a stale cache hit — confirm whether that can occur.
    if return_data:
        _seen[obj_id] = result

    if return_data:
        return result


@overload
def remove_nested_keys(
    keys_to_remove: list[HashableT], obj: NestedDict[HashableT, VT]
) -> NestedDict[HashableT, VT]: ...


@overload
def remove_nested_keys(keys_to_remove: list[HashableT], obj: Any) -> Any: ...


def remove_nested_keys(
    keys_to_remove: list[HashableT], obj: Union[NestedDict[HashableT, VT], Any]
) -> Union[NestedDict[HashableT, VT], Any]:
    """
    Return a copy of `obj` with every key listed in `keys_to_remove` dropped at
    any level of dictionary nesting.

    Only dictionaries are descended into. Values of any other type, including
    lists that happen to contain dictionaries, are returned unchanged.

    Args:
        keys_to_remove: Keys to drop wherever they appear in a nested dict.
        obj: The object to remove keys from.

    Returns:
        A new plain `dict` without the listed keys if `obj` is a dictionary,
        otherwise `obj` itself.
    """
    if isinstance(obj, dict):
        source = cast(NestedDict[HashableT, VT], obj)
        pruned = {}
        for key, value in source.items():
            if key in keys_to_remove:
                continue
            pruned[key] = remove_nested_keys(keys_to_remove, value)
        return pruned
    return obj


@overload
def distinct(
    iterable: Iterable[HashableT], key: None = None
) -> Iterator[HashableT]: ...


@overload
def distinct(iterable: Iterable[T], key: Callable[[T], Hashable]) -> Iterator[T]: ...


def distinct(
    iterable: Iterable[Union[T, HashableT]],
    key: Optional[Callable[[T], Hashable]] = None,
) -> Iterator[Union[T, HashableT]]:
    """
    Lazily yield the unique items of `iterable`, preserving first-seen order.

    Args:
        iterable: The items to de-duplicate; consumed lazily.
        key: Optional function mapping each item to the hashable value that
            decides uniqueness. Defaults to the item itself, which must then be
            hashable.

    Yields:
        Each item whose key has not been produced by an earlier item.
    """

    def _key(__i: Any) -> Hashable:
        return __i

    if key is not None:
        _key = cast(Callable[[Any], Hashable], key)

    seen: set[Hashable] = set()
    for item in iterable:
        # Evaluate the key once per item. Calling it twice doubles the cost of
        # expensive key functions and checks/records different values if the
        # key is not deterministic.
        marker = _key(item)
        if marker in seen:
            continue
        seen.add(marker)
        yield item


@overload
def get_from_dict(
    dct: NestedDict[str, VT], keys: Union[str, list[str]], default: None = None
) -> Optional[VT]: ...


@overload
def get_from_dict(
    dct: NestedDict[str, VT], keys: Union[str, list[str]], default: R
) -> Union[VT, R]: ...


def get_from_dict(
    dct: NestedDict[str, VT], keys: Union[str, list[str]], default: Optional[R] = None
) -> Union[VT, R, None]:
    """
    Fetch a value from a nested dictionary or list using a sequence of keys.

    This function allows to fetch a value from a deeply nested structure
    of dictionaries and lists using either a dot-separated string or a list
    of keys. If a requested key does not exist, the function returns the
    provided default value.

    Integer-like keys are first tried as integers (list indices or integer dict
    keys). If that lookup finds no such key in a mapping, the key is retried in
    its original form, so dictionaries with numeric string keys remain reachable.

    Args:
        dct: The nested dictionary or list from which to fetch the value.
        keys: The sequence of keys to use for access. Can be a
            dot-separated string or a list of keys. List indices can be included
            in the sequence as either integer keys or as string indices in square
            brackets.
        default: The default value to return if the requested key path does not
            exist. Defaults to None.

    Returns:
        The fetched value if the key exists, or the default value if it does not.

    Examples:

    ```python
    get_from_dict({'a': {'b': {'c': [1, 2, 3, 4]}}}, 'a.b.c[1]') # 2
    get_from_dict({'a': {'b': [0, {'c': [1, 2]}]}}, ['a', 'b', 1, 'c', 1]) # 2
    get_from_dict({'a': {'b': [0, {'c': [1, 2]}]}}, 'a.b.1.c.2', 'default') # 'default'
    get_from_dict({'a': {'1': 'x'}}, 'a.1') # 'x'
    ```
    """
    if isinstance(keys, str):
        keys = keys.replace("[", ".").replace("]", "").split(".")
    value: Any = dct
    try:
        for key in keys:
            try:
                # Try to cast to int to handle list indices
                index = int(key)
            except (TypeError, ValueError):
                # Not integer-like: use the key as-is for a mapping lookup
                value = value[key]
                continue
            try:
                value = value[index]
            except KeyError:
                # A mapping without the integer key may still hold the key in
                # its original (e.g. string) form
                value = value[key]
        return cast(VT, value)
    except (TypeError, KeyError, IndexError):
        return default


def set_in_dict(
    dct: NestedDict[str, VT], keys: Union[str, list[str]], value: VT
) -> None:
    """
    Sets a value in a nested dictionary using a sequence of keys, in place.

    Intermediate keys that do not exist are created as new empty dictionaries.
    Only dictionaries are traversed. Bracketed segments such as `a[0]` are
    normalized to `a.0`, but `"0"` is then used as an ordinary dictionary key,
    not as a list index.

    Args:
        dct: The dictionary to set the value in; it is modified in place.
        keys: The sequence of keys to use for access. Can be a
            dot-separated string or a list of keys.
        value: The value to set in the dictionary.

    Returns:
        None. `dct` is mutated in place.

    Raises:
        TypeError: If an intermediate key on the path exists and holds a
            non-dict value.
        IndexError: If `keys` is an empty list.
    """
    if isinstance(keys, str):
        keys = keys.replace("[", ".").replace("]", "").split(".")
    for k in keys[:-1]:
        # An existing non-dict value blocks the path; refuse to overwrite it
        if not isinstance(dct.get(k, {}), dict):
            raise TypeError(f"Key path exists and contains a non-dict value: {keys}")
        if k not in dct:
            dct[k] = {}
        dct = cast(NestedDict[str, VT], dct[k])
    dct[keys[-1]] = value


def deep_merge(
    dct: NestedDict[str, VT1], merge: NestedDict[str, VT2]
) -> NestedDict[str, Union[VT1, VT2]]:
    """
    Recursively merge `merge` on top of `dct` without modifying either input.

    Where both sides hold a dictionary under the same key, the two are merged
    recursively. Otherwise the value from `merge` wins. Only the top level (and
    any recursively merged levels) is copied. Other nested values are shared
    with the inputs by reference.

    Args:
        dct: The base dictionary.
        merge: The dictionary whose values take precedence.

    Returns:
        A new dictionary with the merged contents.
    """
    merged: dict[str, Any] = dct.copy()
    for key, incoming in merge.items():
        existing = merged.get(key)
        if isinstance(existing, dict) and isinstance(incoming, dict):
            merged[key] = deep_merge(
                cast(NestedDict[str, VT1], existing),
                cast(NestedDict[str, VT2], incoming),
            )
            continue
        merged[key] = incoming
    return merged


def deep_merge_dicts(*dicts: NestedDict[str, Any]) -> NestedDict[str, Any]:
    """
    Fold any number of dictionaries together with `deep_merge`, left to right.

    Later dictionaries take precedence over earlier ones. With no arguments an
    empty dictionary is returned.

    Args:
        dicts: The dictionaries to merge, in increasing order of precedence.

    Returns:
        A new dictionary with the merged contents.
    """
    accumulated: NestedDict[str, Any] = {}
    for layer in dicts:
        accumulated = deep_merge(accumulated, layer)
    return accumulated
